{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Task02 - 数据读取与数据扩增"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as ud\n",
    "from torchvision import transforms\n",
    "from torchvision import datasets\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import glob\n",
    "import json"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型数据准备，数据扩增"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 进行数据扩展\n",
    "\n",
    "class SVHDataset(ud.Dataset):\n",
    "    def __init__(self, img_pattern, label_folder, transform=None):\n",
    "        self.img_path = glob.glob(img_pattern)\n",
    "        self.img_label = [v['label'] for k,v in json.load(open(label_folder)).items()]\n",
    "        self.img_path.sort()\n",
    "        self.transform = transform\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        实现了切片方法的获取\n",
    "        \"\"\"\n",
    "        # 批量读取数据\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        # 将原始数据分类10为0, 保证有五位数字\n",
    "        # example: [10]*2 = [10, 10], [2, 3] + [10] = [2, 3, 10]\n",
    "        lbl = np.array(self.img_label[index], dtype=np.int)\n",
    "        lbl = list(lbl) + (5 - len(lbl))*[10]\n",
    "        return img, torch.Tensor(lbl[:5])\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据扩充和训练规范化\n",
    "data_transforms = {\n",
    "    'train': transforms.Compose([\n",
    "                # 缩放到固定尺⼨\n",
    "                transforms.Resize((64, 128)),\n",
    "                transforms.RandomCrop((60, 120)),\n",
    "                transforms.ColorJitter(0.3, 0.3, 0.2),\n",
    "                # 加⼊随机旋转\n",
    "                transforms.RandomRotation(10),\n",
    "                # 将图⽚转换为pytorch 的tesntor\n",
    "                transforms.ToTensor(),\n",
    "                # 对图像像素进⾏归⼀化\n",
    "                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
    "            ]),\n",
    "    'val': transforms.Compose([\n",
    "                # 缩放到固定尺⼨\n",
    "                transforms.Resize((60, 128)),\n",
    "                # 将图⽚转换为pytorch 的tesntor\n",
    "                transforms.ToTensor(),\n",
    "                # 对图像像素进⾏归⼀化\n",
    "                transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])\n",
    "            ]),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x2351f1de208>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAT0AAACPCAYAAACI7gxXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO19aZAlV3Xm972tXq29aHNro8GWZWmwFmCEMNgWCNmyxiD/sD1oxljMyGiYwAHM4DACxh4wXjA2WCZmvChYxAhCGCSMZA0M1sj0ODAOoCVLWEJoYZBQo1Z3a+nu6lrfcvwjsyrPPfny1qvXVe+VneeLqKjMvJn3nrx58757vjwLRQQOh8NRFlRGLYDD4XAMEz7pORyOUsEnPYfDUSr4pOdwOEoFn/QcDkep4JOew+EoFXzSi4DkHpK/stHXknwXyY8cn3RbCyTPJblX7QvJOZK/M0q5HBsLkq8meYxkl+Sr02MfIvmmUcvWL0ox6ZF8bOUBbQWIyO+KyLon0612HwbvA/CH5tj5IvLuXieT3J1OjMfU329shCAkX0jySySfJrmhhqgk30OyZeR+wYB1/SLJr5KcJ7mnR/kFJO9Oy+8mecGA7TRI3pKOHyF5iSknyd8n+Uz69wGS7FWXiPxfEZkC8D11+A8AvJtkYxD5ho1STHqOzQPJGsldAF4J4PMDVLFdRKbSv/dtkFgtAJ8BcM0G1WfxF0rmKRH5/wPW8yyA6wG83xakE8htAD4JYAeATwC47Tgmlq8A+CUAT/UouxbAzwE4H8B5AH4WwH/qt2IR2Q/g2wBeO6BsQ0WpJz2SO0jeQfIQyefS7dPNaT9I8uskj5C8jeROdf3F6S/1YZL32V/QSLvvIfnJdLtJ8pPpL+xhkt8geUqPa24CcCaAv0pXF7++lgypiv0+kn9HcpbkX5M8ca12SZ5K8naSz5J8lOQbjey3pNceBfAGAJcBuEdEFvu5/82GiDwkIh8F8MCoZYkhXTV9BsCTPYovAVADcL2ILInIhwEQwKsGaGdZRK4Xka8A6PQ45WoAHxSRfSLyfQAfRPJc14M9AP7NemUbBUo96SG5/48DeB6SCWUBwP8w5/wygP8I4FQAbQAfBgCSpwH43wB+G8BOAL8G4FaSJ61ThqsBbANwBoATALwplSOAiLweiUrxmnR18YE+Zfh3AP4DgJMBNNJz1mr3ZgD70nv+eQC/S/JSVeeVAG4BsB3ApwD8KICH1nnfK3ic5D6SH1+ZkP8Z4DXpD8IDJP/zJrXxrwB8U0I/0W+mxzejrfvU/n0DtPMgkpXilkepJz0ReUZEbhWReRGZBfA7AH7SnHaTiNwvInMAfgPAL5KsIlEVviAiXxCRrojcCWAvgCvWKUYLyaTzQyLSEZG7ReRon9f2I8PHReRhEVlAovKt8EI92yV5BoBXAHiHiCyKyL0APgLg9arOvxeRz6dtLiCZ/GbXed9PA/jXSH5wXgxgGskEutXxGQDnADgJwBsB/CbJqzahnSkAR8yxI0j6abPbOgJgqojXK8AsknGw5VHqSY/kBMk/J/l4qqr9LYDt6aS2gifU9uMA6gBORPKy/kKqGh4meRjJZLFrnWLcBOBLAD5N8smURK73eW0/MmgOZx7JAI+1eyqAZ9MfgRU8DuA0ta/7BACewzpfRhE5JiJ7RaQtIgcA/CqAnyI5s556hg0R+ZaIPJn+UHwVwB8jWQ1vNI4BsH0xg/X/uAzS1gyAY2aVuRamARzeUKk2CaWe9AC8HcDZAF4qIjMAfiI9rn/hzlDbZyJZIT2N5MW/SUS2q79JEcmR0jGISEtE3isi5wL4MSQk8i8XnW72B5Yh0u6TAHaS1JPYmQC+H5HjmwB+eK021xIp/b+e1cVWgGBzZH4AwHlmtXUeNoenfAChanr+AO2cg1BF3rIo06RXT8n7lb8akl+nBQCH0w8U/73Hdb/ExAZtAsBvAbhFRDpIvqq9huRPk6ymdV7S40NIFCRfSfJH09XlUSSTai+yGQAOANDmEQPLUNSuiDwB4KsAfi+t7zwkX0FjquedAF5EsrlWu6r9l5I8m2SF5AlIuNI9ImJVupXz95B8T591M5Wlke43SY5Fzn+M5Bv6rPtKJh/ASPIiAG9B8pW117krZjm7C8qrqZw1AJVUzpVV/h4k4+AtJMdI/mp6/G8K6rqR5I0RucfU82mkba1MqP8LwH8leRrJU5EsBgrrKsBPAvjiOq8ZDUTkX/wfgMeQ/CLrv99GosrtQbK8fxjJZ3oBUEuv2wPg9wB8HcnE8FcATlT1vhTA/0NienAIyUeFM9W1v1Igz3sAfDLdvgrJR4A5JJPah1fa73HdlUg+ZhwG8GvrlQHJF7mvrNUugNMB3JHW+R0Ab+olu5HtswD+rdoXJHxh0TO5CsB30/b3I3nxfkCV/xmAP1P73wFwWZ/Pe3eP5/2YKv8igHel2w0kKuOP9Fn3zQCeScfMtwG8xZQfA/Dj6faPp2OvXlDXG3rIeaMqvxDA3Uh+mO8BcKEqexeAL6r9uwC8cZ3vwO60jAA+kD7vZ9Nt9ronU9+r0+1dSD58NUb9rvfzx1Roh+O4QPJcJLZkF4mIkFwEsATgwyJyXEbH6cr1syLysg0Q1db9CgBvFpEN/xhB8r8BOCQif77RdZt2GkhUy/NEpLXJbV0K4FYAYwCuEJEvk/wggO+IyJ9sZtsbBZ/0HA5HqVAmTs/hcDh80nM4HOXCcU16JC8n+VDqqnTdRgnlcDgcm4WBOb3U1OFhJH6X+wB8A8BVIvKtomtq1ao06r3tbrUcVqJYWcxESpdocydWwmvIbO6vmLKwbSkss1Lo9iqVSs/jx4OgHlNnUQsVWxJUYe5b3auWPyqHrTP6bIITQ8SGZJ/193ueSLe4jogcm0GF91tlV+w47L3dNbcWXGW6pFarrW43x0PLo0Yji3FQq1aDMqqxoUeJHRfhfp8PGOE4FHNDnXZm2fXoow8/LSJ9uYDW1j6lEBcBeFTSCBMkP43EpKJw0mvU6zjrzOf1LFvutFe3O52wU1pBWWjC1lW3YDu6qh5QrZY9kkYzNNlqNrOH3GiEXaIntuV2+GGs1VrK2mI4MYw1s8l9fHx8dbtuJ30lc9eO0uC08N7qapBWzUCsFUyydSOjvs5ObMJMlsZk2F96oNuXQNdZr0aeDaqFZdLpFpah0vvHxO6HPzTmRVXRptrqGQLh87bWktLN2paOnXiKJtbwmerT7ITb1WPBTGx6aCy1wrKl5UzQVlsfD/snuMz8uO885YTV7bPPCe3MzzwzM/vcuWNHUDauJshGJevnRi3s84YaC2I7NrIoaLezG2othc/quWcyB5DXXPHKx9Enjke9PQ2hO9I+hK5KAACS15LcS3Jvu1Nkc+twOBzDwfGs9Hr9tOXWrSJyA4AbAGBivClSTU6xqxo9++vZHQDaHb1vViuV7LoKw9thsBLItrtm7m23tCxh2xqdtlleq+vqTdO2WsmABduJNFlRRPXNl2VdXa2GfVKvZeqIVtfHqqGM9bpelYVyddRKo02zAg32i6kCjdyqLNgzw6Za/FtctJpLDtR6ltm+q6j2xqdDl2GtSXRb4X231IqqY1YrwXXBZebeNM0SUQErpix8X8y7oymYQL01K8mIfqtPtZqWXsXm3tsCPd8ufDtBP4TPjaL7xD57tS9hmX2P+8XxrPT2IfRLPR2944I5HA7HlsHxTHrfAHAWyeenFuGvA3D7xojlcDgcm4OB1VsRaadO0F8CUAXwMRHZ0pFqHQ6H43g4PYjIFwB8YT3XrHBMlgoITUpMYUefF5YFXwtrYfoA1nqbx+S+VGpOpGsXv7rMcjCaBzM8RQF3k+fmYmXFMjca2b1NTEwEZRPN3l+Lx8fCr7DjY1l/NWrhUNCc3hLCr9b6tmsV8+VY1VOrZG3bL8zVaDSm/kxkxHyBrKr2dH/VKvbZZPc2Zr7W6y/yrcWQNFpYyAJaLy+E3O/y8rKqI7suCciTod0t5ow1LFfWr2mZPq/bNV+Ao5xextXlLCRUWV4uXb9Ezovx1YVF5ryYGUz/cI8Mh8NRKvik53A4SoXjUm8HQ7JUtmqqNq2wqlA9UCvDebpWVdbiRr2t1JVBZMTwXn8yt8twvYKuVKyRqy40lVasaUqPawCINrOwlhtV1SdGRdMG1Va9nZnKzDC0YfSUOW9ycnJ1u2GogHY3U2nZMOqhkqVu1OJagdG0Nd6uSLFJSTeitcSoAm2yRGX2Us0ZJ+vGloOy1nKW0G1xLizTVMFcPneTqjNTkY31B7oRE6XAWDlC/+QHW+86rfFzoHFaBiliQytRtbVI7S6W0Y7zfumM3FW9X7E14Ss9h8NRKvik53A4SgWf9BwOR6kwZE5P0JXkk303574UcyTXO1VTVswNUc/p1LyE5VK0m01oUhCLMBJEWWGxWUSMh7ImH0FZtZjn1PXkAg40sv0xZabSnJgMzptQ3F+9HpqzaFOLSsO0XcvabhjXtjDIg+LYcp5mkX4tLLF9aftE8XhKjsAlEEBVyb+8NB+Wqevaxu2wtqg5w2I+ThR/1e2Y8aTGvb3PqhqvQhuMQPdXWBb2iS6wYVaK+7yqpKHhAhnhIYspN0uiZ/3aMXdeiZjj5Pm/DLUCl8e14Cs9h8NRKvik53A4SoWhm6xk5g7FwTpzQTGrWoUtnqfbZlne0BEhtKV6ZDltA06KkqVmzTOCeHRhPYHK2ZXe2wBqKtJJLuacqnNceVkAwPhktr99+0xQNjaWlS0rU4SldmiWMN5V5iUIzX1WouEAQMeoSfpWm/XQDEY/u2o9q5MmJMbiYmYa0hwPVeuGUsm7RhXSphU26k0QCk8Fj6ubqC3sZDLOL4VmKVotE3tdPdufngmfx+REZkI0N3tsdfvw4aPBea1WZgpko7i09dgwqntVmWbVa8bLQ+22VMxHS81Q9Z2lBuoq3qR9w2JmI/p+xtRYzumlap9da/aiPFg6Rq1XMlfMGGJOhe4PvtJzOBylgk96DoejVPBJz+FwlApD5fQE2edqS6vpqB72k3ZQh7mwo2wh7KfvIveZnOuMtmyxthVBBORiDqFaDd24tLlJwP1ZfiTGL0YSFp1wQpbTYNu2MPqvNutZmM+4s9m50HVqdv6pQjG0K9jEdMj3Tc1kpi/aDRAITV86c1nbOkIJAMwvZLzX9HRoSjM5PbW6rV0JLRaXw+gvi7OZ+9fSknKj61g3uuw5NqeMe18ju/HmRMjbNRVHaZPUyHJmmqJd7pZMXoeuSmKRc/zqFr8DOrK45TmLIs/kTb8Ur2Y5aG3uY5MGRcZh/n1ZGzbgSjWg820uFWVOluMWndNzOByONeGTnsPhKBWGbrKysrS1AQ6j6qcU7qCi6qFVCYrMuXNL8vUk5OldZqOIFHkm5Lwn1HU5S3WlnlTrNqlPpk7n0uapz/xzSqV97shccN4xpX5qdRAAuszqeMFZZwRlWpNvz4SqtQ44uXA083Z4+umng/O0unvKD4TpSme2ZWkGaybZ02IrMzF59tnDQdlTTxxY3X7u2cxUpL1k1aBMxXzxxS8MSsaUeq5NfwCgMqaet6myo0xf5o9lKm29EtIesfGkx33HmGeIVvMKa7CJk4ypkRqi1hOoaLwmZcXJmPT+oDm0Y9ioPNEavtJzOBylgk96DoejVPBJz+FwlArD5/RW9f7+dfWQ7zMRIGJhGAqQ4wmKolQY5COuqITYJkG5SO+kOJqLA8KoGzTmAJpbaTQapkxFLzZ1VlU0jWYz45rGl23y5+y88XHT9lgm88xM6OY2PZWZmNi2l+ez9mZnZ1e3Dx06FJyn3Ze0+Q0ALCxknFilEpp8zM5nXKB18VpU12nTmenxkJur1bU5Udiv1XrmTmbNkEIq2JisKBsfzWvacRFDGGkoLOsE7RWbbvTLQduxrKN059zX+kxQPijiibM2ti3AV3oOh6Nk8EnP4XCUCsNXb7sreW8H+7y9vqV3f3VoSfKar060Yi3hs22d9xQAmh1lva+vyakV+ryOKVNmBMbEQKuVYyafLZSZh3YcqDXCiCg7lWl8vR6qeWMqakhzMmx7+/bMTGXSRH+Bim6i883aiCJ6P5cTV+3H8qdWjTnI9u07V7enJrevbm+bCtXzus51WwvVZ22m0hwL+6SrwpmwYj0yVKQQVdTpFEeCsZFOwn6wmYG0OUuxuVe32+l5PHdeZJzHVFhbZ1FO3EHz9lr0q/quB77SczgcpcKakx7Jj5E8SPJ+dWwnyTtJPpL+3xGrw+FwOLYK+lnp3QjgcnPsOgB3ichZAO5K9x0Oh2PLY01OT0T+luRuc/hKAJek258AsAfAO/ppcEUvt8l5YulgKiohdj7Bc38uMhq5ZNLUn/yLOcMekq1u2WTJRVyHRcBZmMcR3nfxvVm3IbJ3f9nuoBTLv3Asc1ljJeS2Worv69ZMJGjVdrOR8WNV404238pc4Gy4kYlmZhLTMSZKiypKSc78R1QUamXKZCO8HD6StT2zLeRDWyqKM8dCzq3RUO21QqE7dRXVRfW57Vdt2pLjj3WEFBtJRUd0Nn2ieTw97ro2EkxkHOZex6Lzcpye5hMjbUcinm+881ocg3J6p4jIfgBI/5+8cSI5HA7H5mHTv96SvBbAtQBQi8RGczgcjmFg0FnoAMldIrKf5C4AB4tOFJEbANwAAOPjTVnRLWmX2n1+Ms+ZrOjP/GbdWhQBgrReHcW5c0MzkkiOWqtiqnO7gQlDcRBUG5AxpiJr1dTmrK1WtSdHZkrTMTpMRyXImZsPVcC5uUy9nV2wKn92bw2EJis6qki91lQlpm2lHto+OXYsCzBqo8vofsglalL911DXtRaNV8fskdXt5ZbJo6yS/4zXm0HZmKrTRtUpet5LRg3WQT5tkigNS8FIp1j/LDIVyau3va9ZC3FTlH7Pi6jWkfNiWrcMqKgOqt7eDuDqdPtqALcNWI/D4XAMFf2YrNwM4O8BnE1yH8lrALwfwGUkHwFwWbrvcDgcWx79fL29qqDo0g2WxeFwODYdI/iykHJ6ljsTbTYS8iz9usj07aJm+RJdZKMq9xnpNh9ttjf/Z00YOipCcc49LojcEdYzN5dFJd5mElZXJzJZ6orfq06EJh5d9fi7hlebn8/qnzsS8n1L05l5Rns6vK7R1KY0malLLmKJesaagwTCfrX9OKHMZWJJZLS5zJJKjgQAgkz+gwcPBGWTM1lSosXJ8DpK1najFj6rts7n3mdiK2smEuO2ND/XaRfzZYPwb732+0X/bUfq7/Md26jIzO6G5nA4SgWf9BwOR6kwfPU2VWvyEUu0mhqWVVWEkZzXBXVOWeu1kG3HrcV7XwMA7DPYaUy1DhL8GDXP7ptalZD23rJHZ1f99VpmwrJzZ6bmzS2E6lpLmUFMTU0FZfsPZFZI8/OhycfSUuYV0e2aDpNaz+1G1QTyVNFlFhdDuUJvk7B+rdpNTJicu3qcqL5j1+QkrmX3bYN8anOZhemwT6YmVL7fWqh2L+jxpRiMmjEn0vl+894+vQORAiEtYmXWZRXVB9Y0S8uco1IiqmOgWht6ptGo9Twv947pnLsbluxnuCYrDofD8c8SPuk5HI5SYcjqLVXAgf6c8JOTi8s2I8jgIIh9bdMrfSuj/uqbz1ugg2lab4pOz20AWF7K1FEtVcd4B9SUal2ztIFStQ4/F+aiOOnETB1tL4WqVqumg1ii5zYQ9kmFNtBCsXqrYpTm85Kobc101AyDUG9k/drqhvl+RXo77wPAksq5KxKqt61OJktb6bc6F0hynQ4GGtbfVm23uzaAhc7vHEKPE0rxOmYjrCAsNuPr6mbDV3oOh6NU8EnP4XCUCj7pORyOUmEEHhnJPJvn7WJBRCNchP4UbnPg6k/2etsEYewl33pho4Gw2jvCS+66arXnNgA0lKfCmCGmxpRZivZ8SOpRpg+KC9SmGgBQCTwhQq8OzZdZXirgkMzz0J4dOvKI9T7QfFmuDs3OVcPn0WB2r3YI6cgn1cBeyUQpheYCw7FQiUTcCTm+8FlpblbzkBXT56K6MmeW0tV8n5U5EjRWm5R0B+PtYvetsVGeHP3WD/fIcDgcjuODT3oOh6NU2DKhjGNLar30Xs+yvOgTfT5QaJ85LCIyWtVUO9RrkxKbH3epmpmX2Py1VCqtVWn0udumpoOyxnimAur2lpaNiYRSYRdU0FAAWFrI9qenwny5k1OZd8X4eCizfh6ddta2VeVarey+l5aWTFmoTmvovBia9kga1Gpetr1g6tfBFEw64UA11TQB0CMYRSCXCqw6rgITNELqQd9ry6iwepx0LQUTjHMrR3/jPF5WTFlobLZ6G2tvPTk/YvCVnsPhKBV80nM4HKWCT3oOh6NUGKHJSnhU70Zz20YiqeQDk/bW+XN8RqQsBn1uy7h4VSoq8oXi+6wlwvJSS5WF8laUnYfl+3SAUctzau5Dc3oLcyGfOK8S5swZTk87dZ140o6gZGYmizZSb4T91VLttdtZ/TbiBxQft9gKo6wsK5eu5XbI79WCXMOGo1RcoE4GtKACrgLAvAq6OjEeRlJpNjM+zvKogYuahC5w3U7Wtk4mpYPEAkBbIpFIYslzolFQsrJWO2aO1R+GEXy0f1QKtuMud/3W6HA4HP/i4ZOew+EoFUZmsrIeFTOwJDcmBv1a0AemFCYqSexTfr8mMdbMojmWmXnooJ7WhKGuIp3kzTqyOm2gTW36MG9zQCjPFK22LpuIKLqOqvEcOPnkE1e3p7eHJjETk8orwowgpdUHdU5vC9XIkzonrG5PTk4GZbpf2+3iPul0wj5fXsz2FxZUXg+jBuvnMXPy9qBsekYHCi2mDRYXQ5VZ9+Ux1ed2XMS8Iqr1rMxmWBHqvLpWhdWRbSIeDBHzj8BcJqJ2D6reBu9YscVNVOaNgq/0HA5HqeCTnsPhKBV80nM4HKXCyCIn50pivFq193n564zJR5AYKCuzvGBX8QbriUwRQ1EUEZvjVZtF5LmUbN9yQzqBzZEjR4zMKlqH4mpYCdvW/OLkWMg17jgh47oakybCSzPrQBuVWLNRVBFRdp16SnDWjOL4pneEvNr4eObmZp9HW3GGXRNDWKj510yOeqMZnLd9Z8bpTU2GpkDanazbCTlQbf6zvFjs2jY7O7u6vWC42LbmF22UmGpx4p6uWp/EojHrMZTj3/SYNLRgLKlPwI0biq3fN0Kb8eTngOJawkjWJlK2vYk+4Ss9h8NRKqw56ZE8g+SXST5I8gGSb02P7yR5J8lH0v871qrL4XA4Ro1+1Ns2gLeLyD0kpwHcTfJOAG8AcJeIvJ/kdQCuA/COtSorDFShVrzVaBQJmy+3WDUt+pxuj1f6TS7UNW1Xi38zirwiKka3rtpIIUEdWk0NVZq5+UyFOnYsNPnQ0UGqjUz/HK+Huqi+1+ZEmJd2TEVqYaPYdMNCa2UVpSpObQvNXsZV5JbxiTCKS0W1V6la4w3dVqiSV6ASIik1eLxRHAmmWjGUQrvYTEirsIvzC0GZNg3S1MPSUvjcdGSbqvEu0jmQY5Ya3Vao5rVVwqLWciQnsY6kYjyDtIdPJ2c1ok1WjDDKK0KfZ19hKlkq+UK1Ez5vJRY6uWhIm6Teish+Ebkn3Z4F8CCA0wBcCeAT6WmfAPBzA0ngcDgcQ8S6OD2SuwFcCOBrAE4Rkf1AMjECOLngmmtJ7iW5t2OIYYfD4Rg2+p70SE4BuBXA20Tk6Frnr0BEbhCRl4jIS+yXS4fD4Rg2+pqFSNaRTHifEpHPpYcPkNwlIvtJ7gJwcM16ANTYe57VVJflOoLIycV5wHO8V8jPZfp/1dgKdDQX0bV8g+YMQ75B80bddnhda0lzN5lcNsKyJles21ND2YO0TVQPHQ34qUNPBWXarUtvW76yqUxDbOLs9lwm1zhCzlD/Vi4uhJFbbB8Vtd1oKFMdc29oZfW322H9uu2Oua6i+KAJxRM2zI+tHl/dtjEp0dyWeabazW3uaBiV5ujRbB2wpPqkVgl5VJ30qFHN2fsouWyy76xtWrc61bVd1f+VqonarKq01h7aM3PRRthWfJyNPBMkttLvrXmPKjpZeXSpFbath1N9zExXoSdg3+jn6y0BfBTAgyLyIVV0O4Cr0+2rAdw2mAgOh8MxPPSz0ns5gNcD+EeS96bH3gXg/QA+Q/IaAN8D8AubI6LD4XBsHNac9ETkK8jZjq/i0vU0Rpp8pApabbXW6NospWKStUhwnlE/CzwtrFVKTd2emEK9ZI8lJYIJaBjmQc22rTdIcG+RxDNWH9EW7taERKuq2uzCRpdpafOJaqjmadjoLBLJQ6z7RPe/7TvdJ9b0J+74UpxPWMuiZbSUSk2Vie2TpawfliMJhaw5izZL0gFlbZ/XK8UJeAK127xyej+XS1eblNSzsZBjBvQ4t4mIC+rLlxXnq9ZBPSuRAJ+5suAl7j+QqsTelwjcI8PhcJQKPuk5HI5SwSc9h8NRKmyZKCtasc/xRMqUw37u1hxhV+xnfhVVQhNFbRvNhL3PQzzRuJazZkxR9L6WMWeOo13suuF9U4ch7pgbV2YLYpLntJeyenRkilarOPE0rfuPeh7tQ8+EcvXptheU5Xi74n61pjshtMmKMW9gbw7Rmo3o5hrGBKOr+rIViTS9bDi9hQVVpuqwpkD6+UvOBVE9U+Ny1VVl1vmq36T2RddYWJOuWOTkomhCFkF0lmjuLdt2sauZJ/t2OByOPuCTnsPhKBWGrN6KWpLmFunZWUZNDT+TW5sPnZQz0nIkIEMQuaVSbAYRD2AaaTuSkKWqzSxoVYfebQGh90m7XWw2UlGVdGxe3YjJjZZz1uSNjSkVRf0VM/fJ0Qa1mP5TbLKix1QQScV4PmgTlrGGUW91ME3jkaFV1fZyqN6GCYt6J+oBQrU1F6wzkiAnFhy0q5IpB6YuRn0OLLpyplO9vZfsfj7KUe9103ryR2tsfh5dX+k5HI6SwSc9h8NRKvik53A4SoWhx3pa4QSs6q5NJPKFxbxazJ0GQZniUiKXMMfBaH7GXsiCbSDglxShmPOc6eqsR6Ys4F2seYbmQIv5n+Calk2sohsMf//aqg4bGSbGuxRxeraOqGlF5AF1WWw+gQLzH8tD6WTo2tQk1/fPs98AAAaHSURBVJaJdBJDEX+ZdzXLtqOmObno3opr7NiMQtr9K2ayIgXbIScd564H4+r0I7V3HbxXkfrFRuMZEL7SczgcpYJPeg6Ho1QYXShj+zk9+Cxu5+JiM4W4qqVymAaarvXI0N4gKCyzVvK6zJpFBCqOUqdyuXNVGY0KGCTZiXiDRKO/FMhr93ORWtR+rVocCDMW6SSWx7USMcsPWASragXPY3NhVbnAnMWcW9TnOapBik1PYqqjznvbZYTq0FZIxvQrtO4yZmFQXj20wVnVu2l6Xd+2phFs/fpdFJs8N+6iseHwlZ7D4SgVfNJzOBylgk96DoejVBgqp0cS9bEk0kdnHdycJg4sHxAz3Qj4oIBiyH00X93qWp5F815S7EJmzSJifFwoooqIkrPc0KYPNrxMrfc2Qv4kVn8gr5Ur4PtCjicwL7KVBu0Vmzp0IgnDc9mfFPQzjtahr+kWc2BVE2WlX3bJjjUti27OcnpUPJu1gAp4r5wrZjG3HHKs6rgx8dDjt9YN77uinjHN86Z+ebphRB/t5lhRo8jywDoBlsDy+WrbdopysbPmXpUBzWd8pedwOEoFn/QcDkepwGFENVhtjDwE4HEAJwJ4emgNF8PlyGOryLJV5AC2jixbRQ5g68iyIsfzROSkfi4Y6qS32ii5V0ReMvSGXY41sVVk2SpyAFtHlq0iB7B1ZBlEDldvHQ5HqeCTnsPhKBVGNendMKJ2LVyOPLaKLFtFDmDryLJV5AC2jizrlmMknJ7D4XCMCq7eOhyOUsEnPYfDUSoMddIjeTnJh0g+SvK6Ibf9MZIHSd6vju0keSfJR9L/O4Ygxxkkv0zyQZIPkHzrKGQh2ST5dZL3pXK8Nz3+fJJfS+X4C5KNteraQJmqJP+B5B2jkoXkYyT/keS9JPemx4Y+TtJ2t5O8heS30/HyshGMk7PTvlj5O0rybSN6d/5LOlbvJ3lzOobXPUaGNukxcSz8nwB+BsC5AK4iee6w2gdwI4DLzbHrANwlImcBuCvd32y0AbxdRM4BcDGAN6f9MGxZlgC8SkTOB3ABgMtJXgzg9wH8USrHcwCu2WQ5NN4K4EG1PypZXikiFyj7r1GMEwD4YwD/R0R+BMD5SPpmqLKIyENpX1wA4MUA5gH85bDlIHkagLcAeImIvBBJMoXXYZAxIiJD+QPwMgBfUvvvBPDOYbWftrkbwP1q/yEAu9LtXQAeGqY8abu3AbhslLIAmABwD4CXIrFur/V6Zpssw+lIXp5XAbgDiR/60GUB8BiAE82xoT8bADMAvov0Y+MoZVFt/xSAvxuFHABOA/AEgJ1IAqXcAeCnBxkjw1RvV4Rewb702ChxiojsB4D0/8nDbJzkbgAXAvjaKGRJ1cl7ARwEcCeA7wA4LFkGlmE+o+sB/DqyoMQnjEgWAfDXJO8meW16bBTj5AUADgH4eKryf4Tk5IhkWcHrANycbg9VDhH5PoA/BPA9APsBHAFwNwYYI8Oc9HrFgSmtvQzJKQC3AnibiBwdhQwi0pFEbTkdwEUAzul12mbLQfJnARwUkbv14VHIAuDlIvIiJDTMm0n+xBDa7IUagBcB+FMRuRDAHIanVueQcmWvBfDZEbW/A8CVAJ4P4FQAk0iekcWaY2SYk94+AGeo/dMBPDnE9nvhAMldAJD+PziMRknWkUx4nxKRz41SFgAQkcMA9iDhGLeTXAm2Nqxn9HIAryX5GIBPI1Fxrx+FLCLyZPr/IBLu6iKM5tnsA7BPRL6W7t+CZBIc1Tj5GQD3iMiBdH/YcrwawHdF5JCItAB8DsCPYYAxMsxJ7xsAzkq/tjSQLJVvH2L7vXA7gKvT7auR8GubCpIE8FEAD4rIh0YlC8mTSG5Pt8eRDKoHAXwZwM8PSw4AEJF3isjpIrIbybj4GxH598OWheQkyemVbSQc1v0YwTgRkacAPEHy7PTQpQC+NQpZUlyFTLXFCOT4HoCLSU6k79BKf6x/jAyLBE2JxisAPIyEO3r3kNu+GQkX0ELyK3oNEt7oLgCPpP93DkGOVyBZgn8TwL3p3xXDlgXAeQD+IZXjfgC/mR5/AYCvA3gUiSozNuTndAmAO0YhS9refenfAytjdBTjJG33AgB702f0eQA7RjRmJwA8A2CbOjYKOd4L4NvpeL0JwNggY8Td0BwOR6ngHhkOh6NU8EnP4XCUCj7pORyOUsEnPYfDUSr4pOdwOEoFn/QcDkep4JOew+EoFf4J/pDI3xeHhKcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 测试切片的方法\n",
    "img_01, label_01 = SVHDataset('Datasets/mchar_train/*.png', 'Datasets/mchar_train.json')[np.random.choice(range(30000))]\n",
    "plt.figure(figsize=(5, 5))\n",
    "plt.title('Label is %s' % label_01)\n",
    "plt.imshow(img_01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "EPOCH = 10\n",
    "BATCH_SIZE = 30\n",
    "USE_CUDA = True\n",
    "\n",
    "train_loader = ud.DataLoader(\n",
    "    dataset=SVHDataset('Datasets/mchar_train/*.png', 'Datasets/mchar_train.json', data_transforms['train']),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    # num_workers 在windows上报错 设置改为 0\n",
    "    num_workers= (0 if sys.platform.startswith('win') else 10)\n",
    ")\n",
    "val_loader = ud.DataLoader(\n",
    "    dataset=SVHDataset('Datasets/mchar_val/*.png', 'Datasets/mchar_val.json', data_transforms['val']),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    # num_workers 在windows上报错 设置改为 0\n",
    "    num_workers= (0 if sys.platform.startswith('win') else 10)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型结构定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import models\n",
    "\n",
    "# 初始化构建线性网络\n",
    "class SVHN_Model(torch.nn.Module):\n",
    "    def __init__(self, path=None):\n",
    "        super(SVHN_Model, self).__init__()\n",
    "        if path is None:\n",
    "            model_conv = models.resnet18(pretrained=True)\n",
    "        else:\n",
    "            model_conv = models.resnet18(pretrained=False)\n",
    "            model_conv.load_state_dict(torch.load(path))\n",
    "        model_conv.avgpool = torch.nn.AdaptiveAvgPool2d(1)\n",
    "        model_conv = torch.nn.Sequential(*list(model_conv.children())[:-1])\n",
    "        self.cnn = model_conv\n",
    "        # 每个字符有11中情况\n",
    "        self.fc1 = torch.nn.Linear(512, 11)\n",
    "        self.fc2 = torch.nn.Linear(512, 11)\n",
    "        self.fc3 = torch.nn.Linear(512, 11)\n",
    "        self.fc4 = torch.nn.Linear(512, 11)\n",
    "        self.fc5 = torch.nn.Linear(512, 11)\n",
    "    def forward(self, img):\n",
    "        # activation function for\n",
    "        # 容易过拟合导致准确度下降\n",
    "        feat = self.cnn(img)\n",
    "        feat = feat.view(feat.shape[0], -1)\n",
    "        feat = F.dropout2d(feat)\n",
    "        # 排除其他无关元素影响，只留正相关因素\n",
    "        c1 = self.fc1(feat)\n",
    "        c2 = self.fc2(feat)\n",
    "        c3 = self.fc3(feat)\n",
    "        c4 = self.fc4(feat)\n",
    "        c5 = self.fc5(feat)\n",
    "        return c1, c2, c3, c4, c5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 验证、训练、预测方法定义"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "\n",
    "def train_def(train_loader, model, loss_func, optimizer):\n",
    "    \n",
    "    # 切换模型为训练模式\n",
    "    model.train()\n",
    "    train_loss = []\n",
    "\n",
    "    for step, (batch_x, batch_y) in enumerate(train_loader):\n",
    "        # train your data...\n",
    "        if USE_CUDA:\n",
    "            batch_x = batch_x.cuda()\n",
    "            # 将 float32 强制转换为 long\n",
    "            batch_y = batch_y.long().cuda()\n",
    "        predicate = model(batch_x)\n",
    "        # 对应个位置上的字符 -> [11情况概率] <=> [label真实值]\n",
    "        loss = reduce(lambda x, y: x + y, [loss_func(predicate[m], batch_y[:, m]) for m in range(batch_y.shape[1])])\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        train_loss.append(loss.item())\n",
    "    return np.array(train_loss)\n",
    "        \n",
    "def validate_def(val_loader, model, loss_func):\n",
    "    # 切换模型为预测模型\n",
    "    model.eval()\n",
    "    val_loss = []\n",
    "    # 不记录模型梯度信息\n",
    "    with torch.no_grad():\n",
    "        for step, (batch_x, batch_y) in enumerate(val_loader):\n",
    "            if USE_CUDA:\n",
    "                batch_x = batch_x.cuda()\n",
    "                # 将 float32 强制转换为 long\n",
    "                batch_y = batch_y.long().cuda()\n",
    "            predicate = model(batch_x)\n",
    "            # 叠加 loss_func\n",
    "            loss = reduce(lambda x, y: x + y, [loss_func(predicate[m], batch_y[:, m]) for m in range(batch_y.shape[1])])\n",
    "            val_loss.append(loss.item())\n",
    "    return np.array(val_loss)\n",
    "\n",
    "def predict_def(test_loader, model, tta=10):\n",
    "    \n",
    "    model.eval()\n",
    "    test_pred_tta, test_target_tta = None, None\n",
    "\n",
    "    # TTA 次数\n",
    "    for _ in range(tta):\n",
    "        test_pred, test_target = [], []\n",
    "        with torch.no_grad():\n",
    "            for step, (batch_x, batch_y) in enumerate(test_loader):\n",
    "                if USE_CUDA:\n",
    "                    batch_x = batch_x.cuda()\n",
    "                    batch_y = batch_y.long().cuda()\n",
    "                predicate_y = model(batch_x)\n",
    "                output = torch.stack(predicate_y, dim=1)\n",
    "                # 最大概率的索引值\n",
    "                output = torch.argmax(output, dim=2)\n",
    "                test_pred.append(output)\n",
    "                test_target.append(batch_y)\n",
    "                \n",
    "        test_pred, test_target = torch.cat(test_pred), torch.cat(test_target)\n",
    "        \n",
    "        if test_pred_tta is None:\n",
    "            test_pred_tta, test_target_tta = test_pred, test_target\n",
    "        else:\n",
    "            test_pred_tta += test_pred\n",
    "            test_target_tta += test_target\n",
    "    return test_pred_tta, test_target_tta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型执行"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Train loss: 4.03201458311081 \t Val loss: 3.7035960837038693\n",
      "Val Acc 0.312\n",
      "Find better model in Epoch 0, saving model.\n",
      "Epoch: 1, Train loss: 2.5711640647649765 \t Val loss: 3.3322740825350414\n",
      "Val Acc 0.3872\n",
      "Find better model in Epoch 1, saving model.\n",
      "Epoch: 2, Train loss: 2.173556577920914 \t Val loss: 2.988036193176658\n",
      "Val Acc 0.4479\n",
      "Find better model in Epoch 2, saving model.\n",
      "Epoch: 3, Train loss: 1.9228811027407646 \t Val loss: 2.998298345568651\n",
      "Val Acc 0.452\n",
      "Epoch: 4, Train loss: 1.7466339838504792 \t Val loss: 2.753334600411489\n",
      "Val Acc 0.4978\n",
      "Find better model in Epoch 4, saving model.\n",
      "Epoch: 5, Train loss: 1.615467549264431 \t Val loss: 2.627869630288221\n",
      "Val Acc 0.5097\n",
      "Find better model in Epoch 5, saving model.\n"
     ]
    }
   ],
   "source": [
    "model = SVHN_Model()  # define the network\n",
    "if USE_CUDA:\n",
    "    model = model.cuda()\n",
    "\n",
    "# 开启训练模式\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(),lr=0.001)\n",
    "# the target label is NOT an one-hotted\n",
    "loss_func = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "best_loss = 15\n",
    "\n",
    "train_loss_list = []\n",
    "val_loss_list = []\n",
    "val_char_acc_list = []\n",
    "\n",
    "for epoch in range(EPOCH):\n",
    "    \n",
    "    # 损失得分\n",
    "    train_loss = train_def(train_loader, model, loss_func, optimizer)\n",
    "    val_loss = validate_def(val_loader, model, loss_func)\n",
    "    \n",
    "    # 预测值结果与真实值比较关联\n",
    "    val_predict_label, val_target_label = predict_def(val_loader, model, 1)\n",
    "    val_label_pred = np.array([''.join(map(lambda x: str(x.item()), labels[labels!=10])) for labels in val_predict_label])\n",
    "    val_label_target = np.array([''.join(map(lambda x: str(x.item()), labels[labels!=10])) for labels in val_target_label])\n",
    "    \n",
    "    # score 评价得分\n",
    "    val_char_acc = np.sum(val_label_pred == val_label_target) / len(val_label_target)\n",
    "    \n",
    "    #将值添加到list\n",
    "    train_loss_list.append(train_loss)\n",
    "    val_loss_list.append(val_loss)\n",
    "    val_char_acc_list.append(val_char_acc)\n",
    "    \n",
    "    print('Epoch: {0}, Train loss: {1} \\t Val loss: {2}'.format(epoch, np.mean(train_loss), np.mean(val_loss)))\n",
    "    print('Val Acc', val_char_acc)\n",
    "    \n",
    "    # 记录下验证集最佳精度\n",
    "    if np.mean(val_loss) < best_loss:\n",
    "        best_loss = np.mean(val_loss)\n",
    "        print('Find better model in Epoch {0}, saving model.'.format(epoch))\n",
    "        # 保存模型参数\n",
    "        torch.save(model.state_dict(), './model.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "sns.set_style()\n",
    "train_loss = np.hstack(np.array(train_loss_list))\n",
    "print(train_loss.shape)\n",
    "plt.figure(figsize=(18, 5))\n",
    "plt.subplot(1, 3, 1)\n",
    "plt.plot(train_loss)\n",
    "plt.xlabel('step')\n",
    "plt.ylabel('train loss')\n",
    "plt.title('Train loss from step')\n",
    "plt.subplot(1, 3, 2)\n",
    "val_loss = np.hstack(np.array(val_loss_list))\n",
    "plt.plot(val_loss)\n",
    "plt.xlabel('step')\n",
    "plt.ylabel('Val loss')\n",
    "plt.title('Val loss from step')\n",
    "plt.subplot(1, 3, 3)\n",
    "plt.plot(val_char_acc_list)\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('Accuracy Value')\n",
    "plt.title('Val Accuracy from epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
